Add unique prefix - increasing counter #217
base: main
Conversation
I am trying to figure out the lint errors. When I run them locally they all seem to pass. :)
ruff check --fix tests/unit/dataset/test_synthetic.py
End-to-end test: ran the following command against an inference server running a Llama model. vLLM output:
📦 Build Artifacts Available
    if prompt_tokens <= 0:
        return ""
prompt_tokens should never be less than 1. This is either redundant or there is an error in the sampling code.
"prompt": self._create_prompt( | ||
prompt_tokens, start_index, self.request_counter | ||
), | ||
"prompt_tokens_count": prompt_tokens, | ||
"output_tokens_count": output_tokens, | ||
} | ||
|
||
def _create_prompt(self, prompt_tokens: int, start_index: int) -> str: | ||
def _create_prompt( | ||
self, prompt_tokens: int, start_index: int, request_id: int | ||
) -> str: | ||
""" | ||
Create a prompt with unique prefix to prevent vLLM prefix caching. | ||
Args: | ||
prompt_tokens: Target number of tokens for the prompt | ||
start_index: Starting position in the text corpus | ||
request_id: Unique identifier for this request (used as prefix) | ||
Returns: | ||
Generated prompt string with unique prefix | ||
""" | ||
if prompt_tokens <= 0: | ||
return "" | ||
return f"{request_id}: " | ||
|
||
unique_prefix = f"{request_id}: " | ||
|
||
# Calculate how many tokens the prefix uses | ||
prefix_tokens = len(self.processor.tokenize(unique_prefix)) | ||
|
||
# Adjust target tokens to account for the prefix | ||
remaining_tokens = max(1, prompt_tokens - prefix_tokens) |
I don't like how unique_prefix is an arbitrary number of tokens. If prompt_tokens is too low, some or all requests will end up with a prefix that uses more tokens than the prompt is supposed to contain. It would be better to make the prefix always exactly one token. An easy way to do this is to iterate over the tokenizer vocab, something like:

    prefix_iter = iter(t for t in self.processor.get_vocab())
    ...
    unique_prefix = next(prefix_iter)
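
A minimal sketch of that suggestion, assuming self.processor is a Hugging Face tokenizer; the helper below and its names are illustrative, not the PR's final implementation:

    # Illustrative sketch: draw each request's prefix from the tokenizer
    # vocabulary so the prefix always costs exactly one token. Assumes a
    # Hugging Face tokenizer; not the actual PR code.
    from transformers import AutoTokenizer

    processor = AutoTokenizer.from_pretrained("gpt2")
    prefix_iter = iter(processor.get_vocab())  # yields unique token strings

    def create_prompt(prompt_tokens: int, base_text: str) -> str:
        # Exactly one token of prefix, converted back to plain text.
        unique_prefix = processor.convert_tokens_to_string([next(prefix_iter)])
        remaining_ids = processor.encode(base_text)[: max(0, prompt_tokens - 1)]
        return unique_prefix + processor.decode(remaining_ids)

Because the prefix comes from the vocabulary it always costs exactly one token, leaving prompt_tokens - 1 tokens for the body (ignoring possible BPE merges at the prefix/body boundary).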
    def test_validation_positive_values(self):
        """Test that negative values are rejected."""
        with pytest.raises(ValueError):
            SyntheticDatasetConfig(prompt_tokens=-1, output_tokens=20)

        with pytest.raises(ValueError):
            SyntheticDatasetConfig(prompt_tokens=20, output_tokens=-1)

        with pytest.raises(ValueError):
            SyntheticDatasetConfig(prompt_tokens=20, output_tokens=10, samples=-1)
Wrong bounds
    def test_validation_positive_values(self):
        """Test that negative values are rejected."""
        with pytest.raises(ValueError):
            SyntheticDatasetConfig(prompt_tokens=-1, output_tokens=20)

        with pytest.raises(ValueError):
            SyntheticDatasetConfig(prompt_tokens=20, output_tokens=-1)

        with pytest.raises(ValueError):
            SyntheticDatasetConfig(prompt_tokens=20, output_tokens=10, samples=-1)

    def test_validation_nonpositive_values(self):
        """Test that non-positive values are rejected."""
        with pytest.raises(ValueError):
            SyntheticDatasetConfig(prompt_tokens=0, output_tokens=20)

        with pytest.raises(ValueError):
            SyntheticDatasetConfig(prompt_tokens=20, output_tokens=0)

        with pytest.raises(ValueError):
            SyntheticDatasetConfig(prompt_tokens=20, output_tokens=10, samples=0)
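
For reference, a minimal sketch of config-level validation that would satisfy both tests, assuming a pydantic model; the gt=0 constraints and the default for samples are assumptions, not necessarily how SyntheticDatasetConfig is actually defined:

    # Hypothetical sketch: pydantic field constraints that reject both zero
    # and negative values. pydantic's ValidationError subclasses ValueError,
    # so pytest.raises(ValueError) still matches.
    from pydantic import BaseModel, Field

    class SyntheticDatasetConfigSketch(BaseModel):
        prompt_tokens: int = Field(gt=0)
        output_tokens: int = Field(gt=0)
        samples: int = Field(default=1000, gt=0)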
    @pytest.fixture
    def tokenizer(self):
        """Fixture to provide a tokenizer for testing."""
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
The tokenizer needs to be mocked; unit tests should not need to download data to function.
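
A minimal sketch of a mocked fixture, assuming the generator only needs tokenize-like behavior; the exact methods stubbed here are an assumption, not the project's actual test setup:

    # Hypothetical mock tokenizer: whitespace "tokenization" stands in for a
    # real tokenizer so the unit test never downloads anything from the Hub.
    from unittest.mock import MagicMock

    import pytest

    @pytest.fixture
    def tokenizer():
        """Fixture providing a mocked tokenizer for testing."""
        mock = MagicMock()
        mock.tokenize.side_effect = lambda text: text.split()
        mock.encode.side_effect = lambda text: list(range(len(text.split())))
        mock.decode.side_effect = lambda ids: " ".join("token" for _ in ids)
        mock.get_vocab.return_value = {f"tok{i}": i for i in range(1000)}
        return mock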
The SyntheticTextItemsGenerator was generating prompts that could trigger vLLM's automatic prefix caching, leading to prefix cache hit rates of up to 80% in some cases during performance benchmarking.
This change injects a unique prefix into each prompt to guarantee a 0% prefix cache hit rate while maintaining realistic prompt characteristics, as illustrated below.
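
To illustrate the idea (the counter values and corpus text here are made up for this example):

    # Each request gets a distinct "<counter>: " prefix, so no two prompts
    # share a common leading token sequence and vLLM cannot reuse cached
    # prefix blocks across requests.
    corpus_snippet = "The quick brown fox jumps over the lazy dog"
    prompts = [f"{request_id}: {corpus_snippet}" for request_id in range(3)]
    # -> "0: The quick brown fox ...", "1: The quick brown fox ...", "2: ..."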
Test:
Performing some tests on the H200 target accelerator to confirm the fix.